#include "shmhdl.hpp"

#include <atomic>
#include <sys/mman.h>
#include <sys/stat.h>
#include <system_error>
#include <unistd.h>

namespace shm_kernel::shared_memory {

// Flag combinations handed to shm_open() as its oflag argument.
enum O_FLAGS
{
  // Open an already existing object for reading and writing.
  OPEN_ONLY = O_RDWR,
  // Open the object, creating it first when it does not exist yet.
  CREATE_OR_OPEN = OPEN_ONLY | O_CREAT,
  // Create the object; O_EXCL makes the call fail when it already exists.
  CREATE_ONLY = CREATE_OR_OPEN | O_EXCL
};

// Permission bit masks for the shared memory object. Each enumerator grants
// one kind of access to the owner, the group and all other users at once.
enum PERM
{
  EXEC  = S_IXOTH | S_IXGRP | S_IXUSR,
  WRITE = S_IWOTH | S_IWGRP | S_IWUSR,
  READ  = S_IROTH | S_IRGRP | S_IRUSR
};

// Creates a new shared memory object called `name` with room for `nbytes`
// bytes of payload, and registers this handle as its first user.
//
// The object is created exclusively (O_CREAT | O_EXCL), so construction fails
// when an object with that name already exists. It is sized to
// nbytes + sizeof(shm_meta_t): a shm_meta_t header is constructed at offset 0
// and initialised here (ref_count_ = 1, shmsz_ = total size, status OK).
// Only that header is mapped by the constructor; addr_ stays null until
// map() is called.
//
// Throws std::runtime_error when shm_open(), ftruncate() or mmap() fails.
// When a failure happens after the object was created, the descriptor is
// closed and the name is unlinked again before throwing.
shm_handle::shm_handle(std::string_view name, const shmsz_t& nbytes)
{
  // A string_view is not guaranteed to be NUL-terminated, but shm_open() and
  // shm_unlink() expect a C string. Work from the owned copy instead of
  // name.data().
  this->name_                = { name.begin(), name.end() };
  const char* const shm_name = this->name_.c_str();

  const int fd = shm_open(shm_name,
                          O_RDWR | O_CREAT | O_EXCL,
                          PERM::EXEC | PERM::WRITE | PERM::READ);
  if (fd == -1) {
    const std::error_code ec(errno, std::system_category());
    char                  errmsg[256];
    snprintf(errmsg,
             sizeof(errmsg),
             "fail to create shared memory object. (%d) [%s]",
             ec.value(),
             ec.message().c_str());
    throw std::runtime_error(errmsg);
  }

  if (ftruncate(fd, nbytes + sizeof(shm_meta_t)) == -1) {
    // Read errno first: close() and shm_unlink() may overwrite it.
    const std::error_code ec(errno, std::system_category());
    close(fd);
    shm_unlink(shm_name);

    char errmsg[256];
    snprintf(errmsg,
             sizeof(errmsg),
             "fail to ftruncate shared memory object. (%d) [%s]",
             ec.value(),
             ec.message().c_str());
    throw std::runtime_error(errmsg);
  }

  // Cache the stat information; a failing fstat() is not treated as fatal.
  fstat(fd, &this->status_);

  void* const meta_mem = mmap(nullptr,
                              sizeof(shm_meta_t),
                              PROT_EXEC | PROT_READ | PROT_WRITE,
                              MAP_SHARED,
                              fd,
                              0);
  if (meta_mem == MAP_FAILED) {
    const std::error_code ec(errno, std::system_category());
    close(fd);
    shm_unlink(shm_name);

    char errmsg[256];
    snprintf(errmsg,
             sizeof(errmsg),
             "fail to map shared memory meta. (%d). [%s]",
             ec.value(),
             ec.message().c_str());
    throw std::runtime_error(errmsg);
  }

  // Construct the header in place at the start of the shared region.
  this->meta_              = new (meta_mem) shm_meta_t;
  this->meta_->ref_count_  = 1;
  this->meta_->shmsz_      = nbytes + sizeof(shm_meta_t);
  this->meta_->shm_status_ = SHM_STATUS::OK;

  this->fd_   = fd;
  this->addr_ = nullptr;
}

// Opens the existing shared memory object called `name` and registers this
// handle as an additional user by incrementing the shared ref_count_ while
// holding the header's mutex.
//
// Only the shm_meta_t header at offset 0 is mapped here; addr_ stays null
// until map() is called.
//
// Throws std::runtime_error when shm_open() or mmap() fails, or when the
// object is too small to contain a shm_meta_t header. The descriptor is
// closed before throwing.
shm_handle::shm_handle(std::string_view name)
{
  // A string_view is not guaranteed to be NUL-terminated, but shm_open()
  // expects a C string. Work from the owned copy instead of name.data().
  this->name_ = { name.begin(), name.end() };

  const int fd = shm_open(this->name_.c_str(),
                          O_FLAGS::OPEN_ONLY,
                          PERM::EXEC | PERM::WRITE | PERM::READ);
  if (fd == -1) {
    const std::error_code ec(errno, std::system_category());
    char                  errmsg[256];
    snprintf(errmsg,
             sizeof(errmsg),
             "fail to open shared memory object. (%d) [%s]",
             ec.value(),
             ec.message().c_str());
    throw std::runtime_error(errmsg);
  }

  // Touching mapped pages that lie beyond the end of the object raises
  // SIGBUS, so refuse objects that cannot hold the header (for instance one
  // that was created but not sized yet). A failing fstat() itself is still
  // tolerated.
  if (fstat(fd, &this->status_) == 0 &&
      this->status_.st_size < static_cast<off_t>(sizeof(shm_meta_t))) {
    close(fd);
    throw std::runtime_error(
      "fail to open shared memory object. object is smaller than its "
      "metadata header");
  }

  void* const meta_mem = mmap(nullptr,
                              sizeof(shm_meta_t),
                              PROT_WRITE | PROT_EXEC | PROT_READ,
                              MAP_SHARED,
                              fd,
                              0);
  if (meta_mem == MAP_FAILED) {
    const std::error_code ec(errno, std::system_category());
    // Do not leak the descriptor when the header cannot be mapped.
    close(fd);

    char errmsg[256];
    snprintf(errmsg,
             sizeof(errmsg),
             "fail to map shared memory. (%d) [%s]",
             ec.value(),
             ec.message().c_str());
    throw std::runtime_error(errmsg);
  }

  // The header was constructed by the creating handle; just attach to it.
  this->meta_ = static_cast<shm_meta_t*>(meta_mem);
  {
    std::lock_guard<std::mutex> guard(this->meta_->mtx_);
    this->meta_->ref_count_ += 1;
  }

  this->fd_   = fd;
  this->addr_ = nullptr;
}

// Move constructor: takes over the header pointer, descriptor, payload
// mapping, name and cached stat information of `handle`.
//
// The source is left without a descriptor (fd_ == -1), header or mapping, so
// that destroying it neither closes the descriptor nor decrements the shared
// reference count a second time.
shm_handle::shm_handle(shm_handle&& handle) noexcept
{
  this->meta_   = handle.meta_;
  this->fd_     = handle.fd_;
  this->addr_   = handle.addr_;
  this->name_   = std::move(handle.name_);
  this->status_ = std::move(handle.status_);

  // Detach the source from everything that was just transferred.
  handle.meta_ = nullptr;
  handle.fd_   = -1;
  handle.addr_ = nullptr;
}

// Releases this handle's share of the object: decrements the shared
// ref_count_ under the header's mutex, unlinks the name when the count
// reaches zero, unmaps the payload mapping (if any) and closes the
// descriptor.
//
// NOTE(review): the header mapping created by the constructors (meta_) is
// not munmap()'ed anywhere in this file — confirm whether that is intended.
shm_handle::~shm_handle()
{
  // fd_ < 0 means this handle no longer owns a descriptor, e.g. because
  // unlink() was already called. Descriptor 0 is a valid descriptor, so the
  // test has to be "negative", not "non-positive".
  if (fd_ < 0) {
    return;
  }

  {
    std::lock_guard<std::mutex> guard(this->meta_->mtx_);
    this->meta_->ref_count_ -= 1;
    if (this->meta_->ref_count_ == 0)
      shm_unlink(name_.c_str());
  }

  this->unmap();
  close(fd_);
  fd_ = -1;
}

void
shm_handle::update_status() noexcept
{
  //
  fstat(fd_, &status_);
}

void
shm_handle::update_status(std::error_code& ec) noexcept
{
  ec.clear();
  if (fstat(fd_, &status_) == -1) {
    ec.assign(errno, std::system_category());
  }
}

// Maps the whole object (header plus payload, meta_->shmsz_ bytes) into the
// address space and returns a pointer to the payload, i.e. the first byte
// after the shm_meta_t header.
//
// The mapping is created only once; later calls return the cached address.
// Returns nullptr when mmap() fails. In that case addr_ stays null, so a
// later call tries again.
void*
shm_handle::map() noexcept
{
  if (this->addr_ == nullptr) {
    void* const base = mmap(nullptr,
                            this->meta_->shmsz_,
                            PROT_READ | PROT_WRITE | PROT_EXEC,
                            MAP_SHARED,
                            fd_,
                            0);
    // Never store MAP_FAILED: offsetting it would produce a non-null
    // garbage pointer that looks like a valid payload address.
    if (base == MAP_FAILED) {
      return nullptr;
    }
    // Skip the header so callers only see the payload.
    this->addr_ = static_cast<char*>(base) + sizeof(shm_meta_t);
  }
  return this->addr_;
}

// Mapping at a caller-chosen address is not implemented yet: the function
// always returns nullptr and leaves the handle untouched.
//
// The unreachable draft that used to follow the return statement was
// removed. It computed the payload offset as
// sizeof(std::mutex) + sizeof(size_t), whereas the other map()/unmap()
// overloads in this file use sizeof(shm_meta_t), and it did not check mmap()
// for failure.
void*
shm_handle::map(void* addr) noexcept
{
  // !NOT IMPL YET
  (void)addr;
  return nullptr;
}

// Maps the whole object (header plus payload, meta_->shmsz_ bytes) into the
// address space and returns a pointer to the payload, i.e. the first byte
// after the shm_meta_t header.
//
// The mapping is created only once; later calls return the cached address.
// `ec` is cleared first. When mmap() fails, `ec` receives errno in the
// system category, nullptr is returned and addr_ stays null, so a later call
// tries again.
void*
shm_handle::map(std::error_code& ec) noexcept
{
  ec.clear();
  if (this->addr_ == nullptr) {
    void* const base = mmap(nullptr,
                            this->meta_->shmsz_,
                            PROT_READ | PROT_WRITE | PROT_EXEC,
                            MAP_SHARED,
                            fd_,
                            0);
    if (base == MAP_FAILED) {
      // Keep addr_ null: storing MAP_FAILED would make later calls return it
      // as if it were a valid mapping and would make unmap() act on it.
      ec.assign(errno, std::system_category());
      return nullptr;
    }
    // Skip the header so callers only see the payload.
    this->addr_ = static_cast<char*>(base) + sizeof(shm_meta_t);
  }
  return this->addr_;
}

// Mapping at a caller-chosen address is not implemented yet: the function
// always returns nullptr and leaves the handle untouched.
//
// `ec` is set to std::errc::function_not_supported so that callers checking
// the error code can tell the missing implementation apart from success.
// The unreachable draft that used to follow the return statement was
// removed.
void*
shm_handle::map(void* addr, std::error_code& ec) noexcept
{
  // !NOT IMPL YET
  (void)addr;
  ec = std::make_error_code(std::errc::function_not_supported);
  return nullptr;
}

void
shm_handle::unmap() noexcept
{
  if (this->addr_) {
    munmap((char*)addr_ - sizeof(shm_meta_t), this->meta_->shmsz_);
    this->addr_ = nullptr;
  }
}

// Removes the mapping created by map(), if there is one.
//
// `ec` is cleared first, like the other std::error_code overloads of this
// class do. When munmap() fails, `ec` receives errno in the system category
// and addr_ is left unchanged; on success addr_ is reset to null.
void
shm_handle::unmap(std::error_code& ec) noexcept
{
  ec.clear();
  if (this->addr_ == nullptr) {
    return;
  }

  // addr_ points just past the header; step back to the start of the mapping.
  char* const mapping_start =
    static_cast<char*>(this->addr_) - sizeof(shm_meta_t);
  if (munmap(mapping_start, this->meta_->shmsz_) == -1) {
    ec.assign(errno, std::system_category());
    return;
  }
  this->addr_ = nullptr;
}

// Removes the object's name with shm_unlink() and gives up the descriptor.
//
// Afterwards fd_ is -1, which the destructor treats as "nothing left to
// release". The descriptor is therefore closed here; merely forgetting its
// number would leak it for the rest of the process lifetime. A mapping
// obtained from map() is not touched. Errors from shm_unlink() and close()
// are ignored.
void
shm_handle::unlink() noexcept
{
  if (this->fd_ != -1) {
    shm_unlink(name_.c_str());
    close(fd_);
    fd_ = -1;
  }
}

// Removes the object's name with shm_unlink() and gives up the descriptor.
//
// `ec` is cleared first, like the other std::error_code overloads of this
// class do; when shm_unlink() fails it receives errno in the system
// category. In either case fd_ becomes -1, which the destructor treats as
// "nothing left to release", so the descriptor is closed here instead of
// being leaked. A mapping obtained from map() is not touched.
void
shm_handle::unlink(std::error_code& ec) noexcept
{
  ec.clear();
  if (this->fd_ == -1) {
    return;
  }

  const int rv = shm_unlink(name_.c_str());
  // Read errno before close() gets a chance to overwrite it.
  const int unlink_errno = errno;

  close(this->fd_);
  this->fd_ = -1;

  if (rv == -1) {
    ec.assign(unlink_errno, std::system_category());
  }
}
// Returns the payload size in bytes: the object size from the cached stat
// information minus the shm_meta_t header stored at the start of the object.
//
// The creating constructor sizes the object to nbytes + sizeof(shm_meta_t),
// so the whole header has to be subtracted here. The value reflects the last
// fstat() (constructor or update_status()).
const shmsz_t
shm_handle::nbytes() noexcept
{
  return status_.st_size - sizeof(shm_meta_t);
}

const int
shm_handle::fd() noexcept
{
  return this->fd_;
}

std::string_view
shm_handle::name() noexcept
{
  return this->name_;
}

// Returns the cached stat information, as filled in by the last fstat() call
// made by a constructor or by update_status().
const shm_desc_t&
shm_handle::status() noexcept
{
  return status_;
}

void*
shm_handle::addr() noexcept
{
  return this->addr_;
}

const size_t&
shm_handle::ref_count() noexcept
{
  return this->meta_->ref_count_;
}

} // namespace shm_kernel::shared_memory